#include <stdlib.h>
#include <unistd.h>
#include <string.h>
#include <arpa/inet.h>
#include <errno.h>
#include <netdb.h>

#include "callback.h"
#include "socks5.h"
#include "conn.h"

/*
 * Parse a numeric IPv4 or IPv6 address string into a socket address.
 *
 * IPv4 is tried first, then IPv6. On success only the address family and the
 * address bytes of the destination are written; the port and every other
 * field are left untouched.
 *
 * @param src      NUL-terminated numeric address text ("1.2.3.4", "::1", ...)
 * @param addrptr  destination, treated as a struct sockaddr_storage
 * @return the positive inet_pton() result on success, -1 if src is neither
 *         a numeric IPv4 nor a numeric IPv6 address
 */
int strtosockaddr(const char *src, void *addrptr)
{
    struct sockaddr_storage *dst = (struct sockaddr_storage *)addrptr;
    struct in_addr v4;
    struct in6_addr v6;
    int rc;

    rc = inet_pton(AF_INET, src, &v4);
    if (rc > 0) {
        dst->ss_family = AF_INET;
        ((struct sockaddr_in *)addrptr)->sin_addr = v4;
        return rc;
    }

    rc = inet_pton(AF_INET6, src, &v6);
    if (rc <= 0)
        return -1;

    dst->ss_family = AF_INET6;
    ((struct sockaddr_in6 *)addrptr)->sin6_addr = v6;
    return rc;
}

/*
 * Connect callback for the outbound ("remote") socket created in on_recv().
 *
 * While the tunnel is in SOCKS5_CONN_STAGE_CONNECTING it sends the SOCKS5
 * CONNECT reply to the client. The reply is the fixed header followed by the
 * BND.ADDR/BND.PORT bytes previously stored in conn->remote_conn.bndaddr. It
 * then advances the stage:
 *   - CONNECTED when getpeername() confirms the socket is connected;
 *   - CLOSING (rep = SERVER_FAILURE) otherwise.
 * The stage is set BEFORE hio_write() because the client's write callback
 * (on_send) acts on conn->stage. NOTE(review): in libhv that callback may run
 * from inside hio_write() when the data is sent immediately; confirm for the
 * libhv version in use.
 *
 * Unless the tunnel is closing, reading from the remote socket is started.
 */
void remote_connect(hio_t* io)
{
    char localaddrstr[SOCKADDR_STRLEN] = {0};
    char peeraddrstr[SOCKADDR_STRLEN] = {0};
    LOGD("remote_connect connfd=%d [%s] => [%s]", hio_fd(io),
            SOCKADDR_STR(hio_localaddr(io), localaddrstr),
            SOCKADDR_STR(hio_peeraddr(io), peeraddrstr));
    struct tunnel_conn *conn = hevent_userdata(io);
    if (conn == NULL) {
        LOGE("Error buf is null");
        return;
    }

    if (SOCKS5_CONN_STAGE_CONNECTING == conn->stage) {
        LOGD("remote connected, fd: [%d], stage: [%d]",  hio_fd(io), conn->stage);

        /* Same positional field order as the reply built in on_recv():
         * ver, rep, rsv, addrtype. Initialising every field keeps the
         * reserved byte from being sent uninitialised. */
        struct socks5_response reply = {
            SOCKS5_VERSION,
            SOCKS5_RESPONSE_SUCCESS,
            SOCKS5_RSV,
            conn->remote_conn.addrtype
        };

        struct sockaddr_storage storage;
        socklen_t len = sizeof(storage);
        int failed = 0;
        if (getpeername(hio_fd(io), (struct sockaddr *)&storage, &len) < 0) {
            LOGW("getpeername(%s:%d) fail, errno: [%d]", conn->remote_conn.hostname, conn->remote_conn.port, errno);
            failed = 1;
        }

        /* Full reply = header + BND.ADDR/BND.PORT (RFC 1928 section 6);
         * the failure reply carries the address fields too. */
        size_t bndlen = buffer_len(conn->remote_conn.bndaddr);
        size_t msglen = sizeof(reply) + bndlen;
        char *msg = malloc(msglen);
        if (msg == NULL) {
            LOGE("remote_connect: out of memory building socks5 reply");
            failed = 1;
        }

        if (failed) {
            reply.rep = SOCKS5_RESPONSE_SERVER_FAILURE;
            tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CLOSING);
        } else {
            tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CONNECTED);
            LOGI("remote connected host=%s, port=%d", conn->remote_conn.hostname, conn->remote_conn.port);
        }

        if (msg != NULL) {
            memcpy(msg, &reply, sizeof(reply));
            memcpy(msg + sizeof(reply), conn->remote_conn.bndaddr->data, bndlen);
            hio_write(conn->client, msg, msglen);
            /* hio_write() does not keep a reference to msg, so it is released
             * here instead of leaking as the old buffer_new() copy did. */
            free(msg);
        } else {
            /* Could not allocate the full reply: at least tell the client
             * the request failed. */
            hio_write(conn->client, (char *)&reply, sizeof(reply));
        }

        if (failed) {
            /* on_send may already have closed the tunnel; do not touch
             * conn or start reading on a socket that is being torn down. */
            return;
        }
    }
    hio_read(io);
}

/*
 * Close callback for the outbound ("remote") socket.
 *
 * Only logs the descriptor and the io error code; it does not close the
 * client side of the tunnel.
 * NOTE(review): a call to tunnel_conn_close() for the error-free case was
 * previously disabled here. Without it, when the remote side goes away the
 * client connection presumably stays open until the client closes it.
 * Confirm this is intended.
 */
void remote_close(hio_t* io)
{
    int fd = hio_fd(io);
    int err = hio_error(io);

    LOGD("remote_close fd=%d error=%d", fd, err);
}

/*
 * Read callback for the outbound ("remote") socket.
 *
 * Logs the transfer and relays the received bytes verbatim to the client
 * socket of the tunnel attached to io's userdata. Does nothing (besides a
 * debug log) when no tunnel is attached or buf is NULL.
 */
void remote_recv(hio_t* io, void* buf, int readbytes)
{
    char local_str[SOCKADDR_STRLEN] = {0};
    char peer_str[SOCKADDR_STRLEN] = {0};
    struct tunnel_conn *conn;

    LOGD("remote_recv fd=%d readbytes=%d", hio_fd(io), readbytes);
    LOGD("[%s] <=> [%s]",
            SOCKADDR_STR(hio_localaddr(io), local_str),
            SOCKADDR_STR(hio_peeraddr(io), peer_str));

    conn = hevent_userdata(io);
    if (buf != NULL && conn != NULL) {
        /* remote -> client */
        hio_write(conn->client, buf, readbytes);
        return;
    }
    LOGD("Error buf is null");
}

/*
 * Write callback for the outbound ("remote") socket.
 *
 * If the tunnel is still in SOCKS5_CONN_STAGE_CONNECTING, it sends the SOCKS5
 * CONNECT reply to the client. The reply is the header plus the BND.ADDR/BND.PORT
 * bytes from conn->remote_conn.bndaddr. The stage becomes CONNECTED on success
 * or CLOSING (rep = SERVER_FAILURE) if getpeername() fails. In every other
 * stage it only logs.
 *
 * The stage is set before hio_write() because the client's write callback
 * (on_send) acts on conn->stage. NOTE(review): libhv may invoke it from
 * inside hio_write(); confirm for the version in use.
 */
void remote_send(hio_t* io, void* buf, int writebytes)
{
    char *buff = (char *)buf;
    struct tunnel_conn *conn = hevent_userdata(io);
    if (conn == NULL || buff == NULL) {
        LOGD("Error buf is null\n");
        return;
    }

    LOGD("remote_send fd=%d writebytes=%d", hio_fd(io), writebytes);
    if (SOCKS5_CONN_STAGE_CONNECTING != conn->stage) {
        return;
    }

    LOGD("remote connected, fd: [%d], stage: [%d]",  hio_fd(io), conn->stage);

    /* Positional order ver, rep, rsv, addrtype (as in on_recv); every field
     * is initialised so the reserved byte is not sent uninitialised. */
    struct socks5_response reply = {
        SOCKS5_VERSION,
        SOCKS5_RESPONSE_SUCCESS,
        SOCKS5_RSV,
        conn->remote_conn.addrtype
    };

    struct sockaddr_storage storage;
    socklen_t len = sizeof(storage);
    int failed = 0;
    if (getpeername(hio_fd(io), (struct sockaddr *)&storage, &len) < 0) {
        LOGW("getpeername(%s:%d) fail, errno: [%d]", conn->remote_conn.hostname, conn->remote_conn.port, errno);
        failed = 1;
    }

    /* Full reply = header + BND.ADDR/BND.PORT (RFC 1928 section 6). */
    size_t bndlen = buffer_len(conn->remote_conn.bndaddr);
    size_t msglen = sizeof(reply) + bndlen;
    char *msg = malloc(msglen);
    if (msg == NULL) {
        LOGE("remote_send: out of memory building socks5 reply");
        failed = 1;
    }

    if (failed) {
        reply.rep = SOCKS5_RESPONSE_SERVER_FAILURE;
        tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CLOSING);
    } else {
        tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CONNECTED);
        LOGI("remote connected host=%s, port=%d", conn->remote_conn.hostname, conn->remote_conn.port);
    }

    if (msg != NULL) {
        memcpy(msg, &reply, sizeof(reply));
        memcpy(msg + sizeof(reply), conn->remote_conn.bndaddr->data, bndlen);
        hio_write(conn->client, msg, msglen);
        /* hio_write() does not retain msg; release it instead of leaking. */
        free(msg);
    } else {
        hio_write(conn->client, (char *)&reply, sizeof(reply));
    }
    /* conn may have been closed by on_send at this point; do not touch it. */
}

/*
 * Close callback for the client socket.
 *
 * Logs the event and, only when the io closed without an error and has a
 * tunnel attached as userdata, tears that tunnel down via tunnel_conn_close().
 */
void on_close(hio_t* io)
{
    struct tunnel_conn *conn;

    LOGD("on_close fd=%d error=%d", hio_fd(io), hio_error(io));
    if (hio_error(io) != 0)
        return;

    conn = (struct tunnel_conn *)hevent_userdata(io);
    if (conn != NULL)
        tunnel_conn_close(conn);
}

/*
 * Read callback for the client socket: drives the SOCKS5 server state machine
 * according to conn->stage.
 *
 *   EXMETHOD          method negotiation (RFC 1928 section 3)
 *   USERNAMEPASSWORD  username/password sub-negotiation (RFC 1929)
 *   EXHOST            CONNECT request: parses DST.ADDR/DST.PORT, records the
 *                     address bytes in conn->remote_conn.bndaddr and starts an
 *                     outbound TCP connection (remote_connect continues)
 *   STREAM            forwards client bytes to conn->remote
 *
 * Replies to the client are written with hio_write() only AFTER the stage has
 * been updated. The client write callback (on_send) decides the next step,
 * including closing on SOCKS5_CONN_STAGE_CLOSING, from conn->stage.
 * NOTE(review): libhv may run on_send from inside hio_write(); confirm for the
 * version in use. Because of that, conn is not read after a reply is written.
 *
 * NOTE(review): the "need more data" paths return without buffering, so a
 * request split across several reads is dropped. Confirm requests always
 * arrive in one read, or add accumulation.
 *
 * @param io        client connection; userdata is its struct tunnel_conn
 * @param buf       received bytes
 * @param readbytes number of bytes in buf
 */
void on_recv(hio_t* io, void* buf, int readbytes)
{
    char localaddrstr[SOCKADDR_STRLEN] = {0};
    char peeraddrstr[SOCKADDR_STRLEN] = {0};
    LOGD("on_recv fd=%d readbytes=%d", hio_fd(io), readbytes);
    LOGD("[%s] <=> [%s]",
            SOCKADDR_STR(hio_localaddr(io), localaddrstr),
            SOCKADDR_STR(hio_peeraddr(io), peeraddrstr));

    char *buff = (char *)buf;
    struct tunnel_conn *conn = hevent_userdata(io);
    if (conn == NULL || buff == NULL) {
        LOGD("Error buf is null");
        return;
    }

    struct socks5_server *server = conn->server;
    /* Captured up front: conn may be gone after a reply is written. */
    int stage = conn->stage;
    switch (conn->stage) {
    case SOCKS5_CONN_STAGE_EXMETHOD: {
        struct socks5_method_req *method_req = (struct socks5_method_req *)buff;
        // verify version
        if (SOCKS5_VERSION != method_req->ver) {
            LOGD("invalid socks5 version: [%d]", method_req->ver);
            goto _close_conn;
        }
        /* VER + NMETHODS must be present before nmethods may be read. */
        if (readbytes < 2 || readbytes < (method_req->nmethods + 2)) {
            LOGD("need more data");
            // waiting for more data
            return;
        }

        struct socks5_method_res reply = {
            SOCKS5_VERSION,
            SOCKS5_AUTH_NOACCEPTABLE
        };

        for (int i = 0; i < method_req->nmethods; i++) {
            LOGD("auth methods: [%d]", method_req->methods[i]);
            if (server->auth_method == method_req->methods[i]) {
                reply.method = server->auth_method;
                conn->method = reply.method;
            }
        }

        LOGD("auth method: [%d]", reply.method);

        /* Set CLOSING before writing so on_send closes the connection
         * instead of advancing the handshake. */
        if (SOCKS5_AUTH_NOACCEPTABLE == reply.method) {
            tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CLOSING);
        }
        hio_write(io, (char *)&reply, sizeof(reply));
        break;
    }
    case SOCKS5_CONN_STAGE_USERNAMEPASSWORD: {
        struct socks5_userpass_req req;
        memset(&req, 0, sizeof(struct socks5_userpass_req));

        req.ver = *buff;
        if (SOCKS5_AUTH_USERNAMEPASSWORD_VER != req.ver) {
            LOGD("invalid socks5 version: [%d]", *buff);
            goto _close_conn;
        }

        if (readbytes < 2) {
            LOGW("no username len, need more data");
            return;
        }
        /* Length octets are unsigned (0-255). */
        req.ulen = (unsigned char)buff[1];
        if (readbytes < (2 + req.ulen)) {
            LOGW("no username, need more data");
            return;
        }
        if ((size_t)req.ulen > sizeof(req.username)) {
            LOGW("username too long: [%d]", (int)req.ulen);
            goto _close_conn;
        }
        memcpy(req.username, buff + 2, req.ulen);

        if (readbytes < (req.ulen + 3)) {
            LOGW("no password len, need more data");
            return;
        }
        req.plen = (unsigned char)buff[req.ulen + 2];
        if (readbytes < (req.ulen + req.plen + 3)) {
            LOGW("no password, need more data");
            return;
        }
        if ((size_t)req.plen > sizeof(req.password)) {
            LOGW("password too long: [%d]", (int)req.plen);
            goto _close_conn;
        }
        memcpy(req.password, buff + req.ulen + 3, req.plen);

        /* Never log the password; the username is not guaranteed to be
         * NUL-terminated, so print it with an explicit length. */
        LOGD("username: [%.*s]", (int)req.ulen, (const char *)req.username);

        struct socks5_userpass_res res = {
            SOCKS5_AUTH_USERNAMEPASSWORD_VER,
            SOCKS5_AUTH_USERNAMEPASSWORD_STATUS_FAIL
        };

        /* The password must be compared over its own length (plen); comparing
         * only ulen bytes accepted passwords that merely share a prefix. */
        if (server->ulen == req.ulen &&
            server->plen == req.plen &&
            0 == memcmp(&server->username, &req.username, req.ulen) &&
            0 == memcmp(&server->password, &req.password, req.plen)) {
            res.status = SOCKS5_AUTH_USERNAMEPASSWORD_STATUS_OK;
        }

        /* RFC 1929: after a failure status the server MUST close the
         * connection, so mark it CLOSING (on_send closes it). */
        if (SOCKS5_AUTH_USERNAMEPASSWORD_STATUS_FAIL == res.status) {
            tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CLOSING);
        }

        hio_write(io, (char *)&res, sizeof(struct socks5_userpass_res));
        break;
    }
    case SOCKS5_CONN_STAGE_EXHOST: {
        struct socks5_request *req = (struct socks5_request *)buff;
        // verify version
        if (SOCKS5_VERSION != req->ver) {
            LOGD("invalid socks5 version: [%d]", req->ver);
            goto _close_conn;
        }

        // wait more data
        if ((size_t)readbytes < sizeof(struct socks5_request)) {
            LOGW("need more data");
            return;
        }

        /* Field order: ver, rep, rsv, addrtype. */
        struct socks5_response reply = {
            SOCKS5_VERSION,
            SOCKS5_RESPONSE_SUCCESS,
            SOCKS5_RSV,
            SOCKS5_ADDRTYPE_IPV4
        };

        if (SOCKS5_CMD_CONNECT != req->cmd) {
            LOGW("not supported cmd: [%d]", req->cmd);
            reply.rep = SOCKS5_RESPONSE_COMMAND_NOT_SUPPORTED;
            goto _response_fail;
        }

        conn->remote_conn.addrtype = req->addrtype;
        struct sockaddr_storage storage;
        memset(&storage, 0, sizeof(struct sockaddr_storage));

        LOGD("addrtype [%d]", req->addrtype);
        switch (req->addrtype) {
        case SOCKS5_ADDRTYPE_IPV4: {
            if ((size_t)readbytes < (sizeof(struct socks5_request) + 6)) {
                LOGD("wait more data");
                return;
            }

            struct sockaddr_in *addr = (struct sockaddr_in *)&storage;
            addr->sin_family = AF_INET;

            const char *host = buff + sizeof(struct socks5_request);
            const char *port = host + 4;
            memcpy(&addr->sin_addr.s_addr, host, 4);
            memcpy(&addr->sin_port, port, 2);
            conn->remote_conn.port = ntohs(addr->sin_port);

            inet_ntop(AF_INET, &addr->sin_addr.s_addr, conn->remote_conn.hostname, SOCKADDR_STRLEN);
            buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin_addr.s_addr, 4);
            buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin_port, 2);
            break;
        }
        case SOCKS5_ADDRTYPE_DOMAIN: {
            // hostname length
            if ((size_t)readbytes < (sizeof(struct socks5_request) + 1)) {
                LOGD("wait more data");
                return;
            }
            /* The length octet is unsigned; reading it through a (possibly
             * signed) char turned names of 128+ bytes into negative sizes. */
            size_t hostname_len = (unsigned char)buff[sizeof(struct socks5_request)];
            if ((size_t)readbytes < (sizeof(struct socks5_request) + hostname_len + 3)) {
                LOGD("wait more data");
                return;
            }

            /* Keep room for the terminating NUL: hostname is used as a C
             * string below. NOTE(review): assumes remote_conn.hostname is a
             * fixed-size char array (it is filled via inet_ntop with
             * SOCKADDR_STRLEN elsewhere) - confirm in conn.h. */
            if (hostname_len >= sizeof(conn->remote_conn.hostname)) {
                LOGW("hostname too long: [%d]", (int)hostname_len);
                reply.rep = SOCKS5_RESPONSE_SERVER_FAILURE;
                goto _response_fail;
            }
            memcpy(conn->remote_conn.hostname, buff + sizeof(struct socks5_request) + 1, hostname_len);
            conn->remote_conn.hostname[hostname_len] = '\0';

            const char *port = buff + sizeof(struct socks5_request) + 1 + hostname_len;
            uint16_t sin_port;
            memcpy(&sin_port, port, 2);
            conn->remote_conn.port = ntohs(sin_port);

            LOGI("remote hostname: [%s:%d]", conn->remote_conn.hostname, conn->remote_conn.port);

            if (strtosockaddr(conn->remote_conn.hostname, (void *)&storage) > 0) {
                /* The "domain" is actually a numeric address literal. */
                if (storage.ss_family == AF_INET) {
                    conn->remote_conn.addrtype = SOCKS5_ADDRTYPE_IPV4;
                    struct sockaddr_in *addr = (struct sockaddr_in *)&storage;
                    addr->sin_port = htons(conn->remote_conn.port);

                    buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin_addr.s_addr, 4);
                    buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin_port, 2);
                } else if (storage.ss_family == AF_INET6) {
                    conn->remote_conn.addrtype = SOCKS5_ADDRTYPE_IPV6;
                    struct sockaddr_in6 *addr = (struct sockaddr_in6 *)&storage;
                    addr->sin6_port = htons(conn->remote_conn.port);
                    buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin6_addr, 16);
                    buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin6_port, 2);
                }
            } else {
                /* Echo LEN + NAME + PORT exactly as received. */
                buffer_concat(conn->remote_conn.bndaddr, buff + sizeof(struct socks5_request), hostname_len + 3);
            }
            break;
        }
        case SOCKS5_ADDRTYPE_IPV6: {
            if ((size_t)readbytes < (sizeof(struct socks5_request) + 18)) {
                LOGD("wait more data");
                return;
            }

            struct sockaddr_in6 *addr = (struct sockaddr_in6 *)&storage;
            addr->sin6_family = AF_INET6;

            const char *host = buff + sizeof(struct socks5_request);
            const char *port = host + 16;
            memcpy(&addr->sin6_addr, host, 16);
            memcpy(&addr->sin6_port, port, 2);
            conn->remote_conn.port = ntohs(addr->sin6_port);

            /* hostname is what hloop_create_tcp_client() connects to below;
             * it was never filled in for IPv6 requests. */
            inet_ntop(AF_INET6, &addr->sin6_addr, conn->remote_conn.hostname, SOCKADDR_STRLEN);
            buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin6_addr, 16);
            buffer_concat(conn->remote_conn.bndaddr, (char *)&addr->sin6_port, 2);
            break;
        }
        default:
            LOGW("not supported addrtype: [%d]", req->addrtype);
            reply.rep = SOCKS5_RESPONSE_ADDRTYPE_NOT_SUPPORTED;
            goto _response_fail;
        }

        tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CONNECTING);
        LOGI("remote hostname: [%s:%d]", conn->remote_conn.hostname, conn->remote_conn.port);
        conn->remote = hloop_create_tcp_client(hevent_loop(io), conn->remote_conn.hostname, conn->remote_conn.port, remote_connect);
        if (conn->remote == NULL) {
            LOGW("remote connect failed [%s]", conn->remote_conn.hostname);
            reply.rep = SOCKS5_RESPONSE_SERVER_FAILURE;
            goto _response_fail;
        }

        hio_setcb_close(conn->remote, remote_close);
        hio_setcb_read(conn->remote, remote_recv);
        hio_set_connect_timeout(conn->remote, 5000);
        hio_set_close_timeout(conn->remote, 5000);
        hevent_set_userdata(conn->remote, conn);
        break;

    _response_fail: {
            /* reply.rep already holds the specific failure code. RFC 1928
             * replies always carry BND.ADDR/BND.PORT; reply.addrtype is IPv4
             * here, so append 0.0.0.0:0 (six zero bytes). */
            char msg[sizeof(reply) + 6];
            memset(msg, 0, sizeof(msg));
            memcpy(msg, &reply, sizeof(reply));
            tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_CLOSING);
            hio_write(io, msg, sizeof(msg));
            break;
        }
    }
    case SOCKS5_CONN_STAGE_STREAM:
        // send to remote
        hio_write(conn->remote, buff, readbytes);
        break;
    default:
        if (SOCKS5_CONN_STAGE_CLOSING == conn->stage) {
            /* A failure reply is already queued; on_send closes the
             * connection once it has been written. */
            return;
        }
        LOGW("unexpect stage [%d]", conn->stage);
        goto _close_conn;
    }
    LOGD("ok stage [%d] data", stage);
    return;

_close_conn:
    /* Protocol violation or unexpected data: drop the tunnel (same teardown
     * on_send uses for the CLOSING stage). */
    LOGW("unexpect stage [%d] data", conn->stage);
    tunnel_conn_close(conn);
}

/*
 * Write-completion callback for the client socket; advances the handshake.
 *
 *   EXMETHOD          -> EXHOST (no-auth) or USERNAMEPASSWORD (user/pass),
 *                        then keep reading from the client
 *   USERNAMEPASSWORD  -> EXHOST, then keep reading from the client
 *   CONNECTED         -> STREAM, then read from both client and remote
 *
 * If, after that, the tunnel is in SOCKS5_CONN_STAGE_CLOSING (a failure reply
 * has just been flushed), the tunnel is torn down with tunnel_conn_close().
 */
void on_send(hio_t* io, void* buf, int writebytes)
{
    char local_str[SOCKADDR_STRLEN] = {0};
    char peer_str[SOCKADDR_STRLEN] = {0};
    LOGD("on_send connfd=%d [%s] <= [%s]", hio_fd(io),
            SOCKADDR_STR(hio_localaddr(io), local_str),
            SOCKADDR_STR(hio_peeraddr(io), peer_str));

    struct tunnel_conn *conn = hevent_userdata(io);
    if (buf == NULL || conn == NULL) {
        LOGD("Error buf is null");
        return;
    }

    switch (conn->stage) {
    case SOCKS5_CONN_STAGE_EXMETHOD:
        /* method reply sent: move on to auth or straight to the request */
        if (conn->method == SOCKS5_AUTH_NOAUTH)
            tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_EXHOST);
        else if (conn->method == SOCKS5_AUTH_USERNAMEPASSWORD)
            tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_USERNAMEPASSWORD);
        hio_read(io);
        break;
    case SOCKS5_CONN_STAGE_USERNAMEPASSWORD:
        /* auth reply sent: wait for the CONNECT request */
        tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_EXHOST);
        hio_read(io);
        break;
    case SOCKS5_CONN_STAGE_CONNECTED:
        /* CONNECT reply sent: start relaying in both directions */
        tunnel_conn_setstage(conn, SOCKS5_CONN_STAGE_STREAM);
        hio_read(conn->client);
        hio_read(conn->remote);
        break;
    default:
        break;
    }

    LOGD("fd: [%d], stage: [%d]", hio_fd(io), conn->stage);

    if (SOCKS5_CONN_STAGE_CLOSING != conn->stage)
        return;

    LOGD("client_send_cb close conn, fd: [%d], stage: [%d]", hio_fd(io), conn->stage);
    tunnel_conn_close(conn);
}

/*
 * Accept callback for the listening socket.
 *
 * Allocates a tunnel for the new client connection and links the two:
 * conn->client, conn->server (from the loop's userdata) and conn->loop are set,
 * and conn is stored as the io's userdata. Then the client callbacks are
 * installed and reading starts. If the tunnel cannot be allocated, the
 * accepted connection is closed rather than left open with no state attached.
 */
void on_accept(hio_t* io)
{
    if (io == NULL) {
        LOGE("on_accept fail");
        return;
    }

    char localaddrstr[SOCKADDR_STRLEN] = {0};
    char peeraddrstr[SOCKADDR_STRLEN] = {0};
    LOGD("accept connfd=%d [%s] <= [%s]", hio_fd(io),
            SOCKADDR_STR(hio_localaddr(io), localaddrstr),
            SOCKADDR_STR(hio_peeraddr(io), peeraddrstr));

    struct tunnel_conn *conn = tunnel_conn_new();
    if (NULL == conn) {
        LOGE("socks5_conn_new fail: [%d]", errno);
        /* Previously the io stayed open with reads enabled but no userdata,
         * so every read was ignored and the socket leaked. */
        hio_close(io);
        return;
    }

    conn->client = io;
    conn->server = hloop_userdata(hevent_loop(io));
    conn->loop = hevent_loop(io);
    hevent_set_userdata(io, conn);

    /* Install callbacks and start reading only once the tunnel is attached. */
    hio_setcb_close(io, on_close);
    hio_setcb_read(io, on_recv);
    hio_setcb_write(io, on_send);
    hio_read(io);
}

